{
 "cells": [
  {
   "cell_type": "raw",
   "id": "da299444",
   "metadata": {},
   "source": [
    "Copyright 2023 Google LLC\n",
    "\n",
    "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "you may not use this file except in compliance with the License.\n",
    "You may obtain a copy of the License at\n",
    "\n",
    "    http://www.apache.org/licenses/LICENSE-2.0\n",
    "\n",
    "Unless required by applicable law or agreed to in writing, software\n",
    "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "See the License for the specific language governing permissions and\n",
    "limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9d0817cd-9e64-4d5e-9c66-4e0961aa1085",
   "metadata": {},
   "source": [
    "# MLOps End to End Workflow (I)\n",
    "\n",
    "Implementation of an end-to-end MLOps workflow for detecting fraudulent credit card transactions, using the [Kaggle dataset](https://www.kaggle.com/datasets/mlg-ulb/creditcardfraud).\n",
    "\n",
    "This set of notebooks covers:\n",
    "\n",
    "[Experimentation](01-experimentation.ipynb):\n",
    "1. Set up: Creation of the Vertex Dataset, extraction of the schema\n",
    "2. Implementation of a TFX pipeline\n",
    "\n",
    "[CICD](02-cicd.ipynb):\n",
    "\n",
    "3. Deployment of the Vertex AI Pipeline through a CI/CD process\n",
    "4. Deployment of a Continuous Training pipeline that can be triggered via Pub/Sub and produces a model in the Model Registry\n",
    "5. Deployment of the Inference Pipeline consisting of a Cloud Function that retrieves features from Feature Store and calls the model on a Vertex AI Endpoint\n",
    "6. Deployment of the model to a Vertex AI Endpoint through a CI/CD process\n",
    "\n",
    "[Prediction](03-prediction.ipynb):\n",
    "\n",
    "7. Deploy the model to an endpoint\n",
    "8. Create a test prediction\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d62219f7-f6fe-476e-bf88-42db7978460c",
   "metadata": {
    "tags": []
   },
   "source": [
    "## Setup\n",
    "\n",
    "### Package installations\n",
    "\n",
    "Run:\n",
    "\n",
    "```\n",
    "pip install --user -r requirements.txt\n",
    "```\n",
    "\n",
    "If importing `tensorflow` later fails with a `numpy`-related error, force a reinstall of numpy:\n",
    "\n",
    "```\n",
    "pip install --user numpy==1.21.6 --force-reinstall\n",
    "```\n",
    "\n",
    "Moreover, running the end-to-end unit test of the pipeline requires forcing a re-install of `grpcio` and `grpcio-tools`:\n",
    "\n",
    "```\n",
    "pip install --user grpcio --no-binary=grpcio --force-reinstall\n",
    "pip install --user grpcio-tools --no-binary=grpcio-tools --force-reinstall\n",
    "```\n",
    "\n",
    "### Setup config\n",
    "Edit the file `mainconfig.yaml` and set the configuration parameters appropriate for your environment.\n",
    "\n",
    "### Load data into BigQuery\n",
    "\n",
    "We use the Credit Card Fraud Detection dataset from the Machine Learning Group at the Université Libre de Bruxelles, available on [Kaggle](https://www.kaggle.com/datasets/mlg-ulb/creditcardfraud).\n",
    "\n",
    "Create a dataset called `creditcards` in the `EU` region.\n",
    "\n",
    "To load the data into BigQuery, run the following from Cloud Shell:\n",
    "\n",
    "```\n",
    "bq load --skip_leading_rows=1 creditcards.creditcards gs://cxb1-prj-test-no-vpcsc/csv/creditcard.csv Time:STRING,V1:FLOAT,V2:FLOAT,V3:FLOAT,V4:FLOAT,V5:FLOAT,V6:FLOAT,V7:FLOAT,V8:FLOAT,V9:FLOAT,V10:FLOAT,V11:FLOAT,V12:FLOAT,V13:FLOAT,V14:FLOAT,V15:FLOAT,V16:FLOAT,V17:FLOAT,V18:FLOAT,V19:FLOAT,V20:FLOAT,V21:FLOAT,V22:FLOAT,V23:FLOAT,V24:FLOAT,V25:FLOAT,V26:FLOAT,V27:FLOAT,V28:FLOAT,Amount:FLOAT,Class:STRING\n",
    "```\n",
    "\n",
    "### Import Libraries\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2b997d6a-f8fd-43f5-bafd-81b169965160",
   "metadata": {},
   "outputs": [],
   "source": [
    "#%load_ext autoreload\n",
    "#%autoreload 2\n",
    "\n",
    "import os\n",
    "import pandas as pd\n",
    "import tensorflow as tf\n",
    "import tensorflow_data_validation as tfdv\n",
    "from google.cloud import bigquery\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "from google.cloud import aiplatform as vertex_ai\n",
    "\n",
    "import yaml"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3bdeae6-b5c9-479d-8318-627959bf9898",
   "metadata": {},
   "outputs": [],
   "source": [
    "with open('mainconfig.yaml') as f:\n",
    "    main_config = yaml.safe_load(f)\n",
    "\n",
    "# Select the configuration block for this use case\n",
    "main_config = main_config['creditcards']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f5c0f8ad-ca30-4f5c-a135-809409f58abd",
   "metadata": {},
   "source": [
    "### Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3b2ad96-7e3a-4980-a8d2-ee78993f0732",
   "metadata": {},
   "outputs": [],
   "source": [
    "PROJECT = main_config['project'] \n",
    "REGION = main_config['region'] \n",
    "DOCKER_REPO = main_config['docker_repo']\n",
    "\n",
    "SERVICE_ACCOUNT = main_config['service_account']\n",
    "\n",
    "# BigQuery and data locations\n",
    "\n",
    "BQ_SOURCE_TABLE= main_config['bq']['source_table'] # raw input\n",
    "ML_TABLE = main_config['bq']['ml_table'] # the one we will use for the training\n",
    "\n",
    "BQ_DATASET_NAME = main_config['bq']['dataset']\n",
    "BQ_LOCATION = main_config['bq']['location'] # multiregion provides more resilience\n",
    "\n",
    "VERTEX_DATASET_NAME = main_config['vertex_dataset_name']\n",
    "\n",
    "RAW_SCHEMA_DIR = main_config['raw_schema_dir']\n",
    "\n",
    "BUCKET =  main_config['bucket']\n",
    "\n",
    "# TFX and model config\n",
    "\n",
    "# model version\n",
    "VERSION = main_config['version']\n",
    "\n",
    "\n",
    "MODEL_DISPLAY_NAME = f'{VERTEX_DATASET_NAME}-classifier-{VERSION}'\n",
    "WORKSPACE = f'gs://{BUCKET}/{VERTEX_DATASET_NAME}'\n",
    "\n",
    "MLMD_SQLLITE = 'mlmd.sqllite'\n",
    "ARTIFACT_STORE = os.path.join(WORKSPACE, 'tfx_artifacts_interactive')\n",
    "MODEL_REGISTRY = os.path.join(WORKSPACE, 'model_registry')\n",
    "PIPELINE_NAME = f'{MODEL_DISPLAY_NAME}-train-pipeline'\n",
    "PIPELINE_ROOT = os.path.join(ARTIFACT_STORE, PIPELINE_NAME)\n",
    "\n",
    "ENDPOINT_DISPLAY_NAME = f'{VERTEX_DATASET_NAME}-classifier'\n",
    "\n",
    "FEATURESTORE_ID = main_config['featurestore_id']\n",
    "\n",
    "CF_REGION = main_config['cloudfunction_region']\n",
    "\n",
    "DATAFLOW_SUBNETWORK = f\"https://www.googleapis.com/compute/v1/projects/{PROJECT}/regions/{REGION}/subnetworks/{main_config['dataflow']['subnet']}\"\n",
    "DATAFLOW_SERVICE_ACCOUNT = main_config['dataflow']['service_account']\n",
    "\n",
    "CLOUDBUILD_SA = f'projects/{PROJECT}/serviceAccounts/{SERVICE_ACCOUNT}'\n",
    "\n",
    "LIMIT=main_config['limit']"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "274eaeb4-a9db-42f6-a19d-bc9567fa5b37",
   "metadata": {},
   "source": [
    "#### Generate pip configuration\n",
    "\n",
    "Our containers should install packages from our internal Artifact Registry rather than from the public PyPI."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69a752e4-275c-464d-a842-c3feae9e5355",
   "metadata": {},
   "outputs": [],
   "source": [
    "pip_conf = f'''[global]\n",
    "index-url = https://{main_config['artifactregistry_region']}-python.pkg.dev/{PROJECT}/{main_config['python_pkg_repo']}/simple/\n",
    "timeout=10\n",
    "'''\n",
    "\n",
    "with open('build/pip.conf', 'w') as f:\n",
    "    f.write(pip_conf)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fda3ab18-7560-4d47-a911-101d21e8bfd7",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Project ID:\", PROJECT)\n",
    "print(\"Region:\", REGION)\n",
    "\n",
    "vertex_ai.init(\n",
    "    project=PROJECT,\n",
    "    location=REGION\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb98a921-cba7-46ee-baa7-b0a4461c3b2c",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.environ[\"VERTEX_DATASET_NAME\"] = VERTEX_DATASET_NAME\n",
    "os.environ[\"MODEL_DISPLAY_NAME\"] = MODEL_DISPLAY_NAME\n",
    "os.environ[\"PIPELINE_NAME\"] = PIPELINE_NAME\n",
    "os.environ[\"PROJECT\"] = PROJECT\n",
    "os.environ['GOOGLE_CLOUD_PROJECT'] = PROJECT\n",
    "os.environ[\"REGION\"] = REGION\n",
    "os.environ[\"GCS_LOCATION\"] = f\"gs://{BUCKET}/{VERTEX_DATASET_NAME}\"\n",
    "os.environ[\"MODEL_REGISTRY_URI\"] = os.path.join(os.environ[\"GCS_LOCATION\"], \"model_registry\")\n",
    "os.environ[\"TRAIN_LIMIT\"] = \"85000\"\n",
    "os.environ[\"TEST_LIMIT\"] = \"15000\"\n",
    "os.environ[\"BEAM_RUNNER\"] = \"DataflowRunner\"\n",
    "os.environ[\"TRAINING_RUNNER\"] = \"vertex\"\n",
    "os.environ[\"TFX_IMAGE_URI\"] = f\"{DOCKER_REPO}/vertex:latest\"\n",
    "os.environ[\"ENABLE_CACHE\"] = \"1\"\n",
    "os.environ[\"SUBNETWORK\"] = DATAFLOW_SUBNETWORK\n",
    "os.environ[\"SERVICE_ACCOUNT\"] = DATAFLOW_SERVICE_ACCOUNT\n",
    "os.environ[\"BQ_LOCATION\"] = BQ_LOCATION\n",
    "os.environ[\"BQ_DATASET_NAME\"] = BQ_DATASET_NAME\n",
    "os.environ[\"ML_TABLE\"] = ML_TABLE\n",
    "# Overrides the GCS_LOCATION set above; MODEL_REGISTRY_URI keeps the earlier base path.\n",
    "os.environ[\"GCS_LOCATION\"] = f\"gs://{BUCKET}/{VERTEX_DATASET_NAME}/e2e_tests\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6febb4b2-d832-4526-afde-30266bd725cc",
   "metadata": {},
   "source": [
    "# Generate ML data\n",
    "\n",
    "We add an `ML_use` column to pre-split the data: about 80% of the rows are set to `UNASSIGNED` and the remaining 20% to `TEST`.\n",
    "This column is used during training to separate the training data from the test data.\n",
    "\n",
    "In the training phase, the `UNASSIGNED` rows are split into `train` and `eval`. The `TEST` split is used for the final model validation."
   ]
  },
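  {
   "cell_type": "markdown",
   "id": "hash-split-sketch-01",
   "metadata": {},
   "source": [
    "The pre-split described above hashes a column and buckets the hash modulo 100, so the same row always lands in the same split. A minimal local sketch of that idea follows. It is an illustration only: SHA-256 stands in for BigQuery's `FARM_FINGERPRINT`, and the helper name `assign_ml_use` is hypothetical.\n",
    "\n",
    "```python\n",
    "import hashlib\n",
    "\n",
    "\n",
    "def assign_ml_use(key: str, unassigned_pct: int = 80) -> str:\n",
    "    # Deterministically map a key to 'UNASSIGNED' or 'TEST'.\n",
    "    # Assumption: SHA-256 replaces FARM_FINGERPRINT; any stable hash\n",
    "    # gives the same kind of repeatable split.\n",
    "    digest = hashlib.sha256(key.encode('utf-8')).digest()\n",
    "    bucket = int.from_bytes(digest[:8], 'big') % 100  # bucket in 0..99\n",
    "    return 'UNASSIGNED' if bucket < unassigned_pct else 'TEST'\n",
    "\n",
    "\n",
    "keys = [str(t) for t in range(10_000)]\n",
    "labels = [assign_ml_use(k) for k in keys]\n",
    "share = labels.count('UNASSIGNED') / len(labels)\n",
    "print(f'UNASSIGNED share: {share:.3f}')\n",
    "```\n",
    "\n",
    "Because the assignment depends only on the key, re-running the split reproduces it exactly, and rows sharing a key always stay together in one split."
   ]
  },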
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "55b9ffe2-d80a-4149-8de7-6a9ef82c85d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "sql_script = f'''\n",
    "CREATE OR REPLACE TABLE `{PROJECT}.{BQ_DATASET_NAME}.{ML_TABLE}` \n",
    "AS (\n",
    "    SELECT\n",
    "      * EXCEPT(Class),\n",
    "      CAST(Class AS FLOAT64) as Class,\n",
    "      IF(ABS(MOD(FARM_FINGERPRINT(Time), 100)) < 80, 'UNASSIGNED', 'TEST') AS ML_use\n",
    "    FROM\n",
    "      `{PROJECT}.{BQ_DATASET_NAME}.{BQ_SOURCE_TABLE}`\n",
    ")\n",
    "'''\n",
    "\n",
    "bq_client = bigquery.Client(project=PROJECT, location=BQ_LOCATION)\n",
    "job = bq_client.query(sql_script)\n",
    "job.result()  # wait for the table to be created before querying it"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d74b54f8-9509-465c-919f-8c02c3a1aa6a",
   "metadata": {},
   "source": [
    "# Data Exploration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0391720d-1aff-47fb-a0c6-8e46323f4d79",
   "metadata": {},
   "outputs": [],
   "source": [
    "from google.cloud import bigquery\n",
    "\n",
    "client = bigquery.Client(project=PROJECT)  \n",
    "\n",
    "# Use the ML table, excluding the Time and ML_use columns, because this sample is later used to generate\n",
    "# the schema for training\n",
    "sql = f\"SELECT * EXCEPT(time, ml_use) FROM `{PROJECT}.{BQ_DATASET_NAME}.{ML_TABLE}` LIMIT 1000\"\n",
    "print(sql)\n",
    "\n",
    "query_job = client.query(sql, location=BQ_LOCATION)\n",
    "sample_data = query_job.result().to_dataframe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "49686501-b357-48f2-853d-1b2672c43f7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9a520b19-87a7-482a-9de1-cdc862c12243",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bigquery counts --project {PROJECT} \n",
    "\n",
    "SELECT \n",
    "  Class, count(*) as n\n",
    "FROM `creditcards.creditcards`\n",
    "GROUP BY Class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "243a9a1d-84b9-4943-ad77-df5b8cf97d59",
   "metadata": {},
   "outputs": [],
   "source": [
    "counts.plot(kind='bar', x='Class', y='n', logy=True, legend=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "539820cd-f632-477c-a7ce-88cee2ad9340",
   "metadata": {},
   "outputs": [],
   "source": [
    "sample_data.V4.hist()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9223d4a-c6eb-429f-8b17-e80a488733e5",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%bigquery --project {PROJECT}\n",
    "\n",
    "SELECT ML_use, Class, COUNT(*) as n\n",
    "FROM creditcards.creditcards_ml\n",
    "GROUP BY ML_use, Class"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7b4eea14-66df-4651-8c0e-3b035e7da22c",
   "metadata": {},
   "source": [
    "# Generate Schema\n",
    "\n",
    "\n",
    "The [TensorFlow Data Validation (TFDV)](https://www.tensorflow.org/tfx/data_validation/get_started) data schema is used to:\n",
    "1. Identify the raw data types and shapes in the data transformation.\n",
    "2. Create the serving input signature for the custom model.\n",
    "3. Validate the new raw training data in the TFX pipeline."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b0633ec5-ab02-43fd-bcce-c32c62311d21",
   "metadata": {},
   "outputs": [],
   "source": [
    "stats = tfdv.generate_statistics_from_dataframe(\n",
    "    dataframe=sample_data,\n",
    "    stats_options=tfdv.StatsOptions(\n",
    "        label_feature='Class',\n",
    "        weight_feature=None,\n",
    "        sample_rate=1,\n",
    "        num_top_values=50\n",
    "    )\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "80d78215-3700-4065-8e78-6d9b6f1131ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "tfdv.visualize_statistics(stats)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "69d116ce-8fa5-491b-9ea3-00fbcd582ecc",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "schema = tfdv.infer_schema(statistics=stats)\n",
    "tfdv.display_schema(schema=schema)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "33419848-88f1-4b28-9a53-9c7c723578cc",
   "metadata": {},
   "outputs": [],
   "source": [
    "raw_schema_location = os.path.join(RAW_SCHEMA_DIR, 'schema.pbtxt')\n",
    "tfdv.write_schema_text(schema, raw_schema_location)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "326e9b0f-90d0-4b9c-a1a6-f18131dfcdf8",
   "metadata": {},
   "source": [
    "# Create Vertex Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f596319-e9a9-4743-902f-19b5a8694eaa",
   "metadata": {},
   "outputs": [],
   "source": [
    "bq_uri = f\"bq://{PROJECT}.{BQ_DATASET_NAME}.{ML_TABLE}\"\n",
    "\n",
    "dataset = vertex_ai.TabularDataset.create(\n",
    "    display_name=VERTEX_DATASET_NAME, bq_source=bq_uri)\n",
    "\n",
    "dataset.gca_resource"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c236caba-18d4-4454-ba79-2a2b7bcbad32",
   "metadata": {},
   "source": [
    "## Retrieve and inspect the Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a6988e8a-51dd-4920-98e0-8902107f5a55",
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset = vertex_ai.TabularDataset.list(\n",
    "    filter=f\"display_name={VERTEX_DATASET_NAME}\", \n",
    "    order_by=\"update_time\")[-1]\n",
    "\n",
    "print(\"Dataset resource name:\", dataset.resource_name)\n",
    "print(\"Dataset BigQuery source:\", dataset.gca_resource.metadata['inputConfig']['bigquerySource']['uri'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ca5baa2e-0789-4547-92ae-5d8d258b2375",
   "metadata": {},
   "source": [
    "# Build the TFX Pipeline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d37c85f-410f-4222-9b22-6e5fcbaf04bd",
   "metadata": {},
   "outputs": [],
   "source": [
    "import tfx.v1 as tfx\n",
    "from tfx.extensions.google_cloud_big_query.example_gen.component import BigQueryExampleGen\n",
    "from tfx.proto import example_gen_pb2, transform_pb2\n",
    "\n",
    "import tensorflow as tf\n",
    "import tensorflow_transform as tft\n",
    "import tensorflow_data_validation as tfdv\n",
    "import tensorflow_model_analysis as tfma\n",
    "from tensorflow_transform.tf_metadata import schema_utils\n",
    "\n",
    "\n",
    "import ml_metadata as mlmd\n",
    "from ml_metadata.proto import metadata_store_pb2\n",
    "from tfx.orchestration.experimental.interactive.interactive_context import InteractiveContext\n",
    "\n",
    "import logging\n",
    "import json\n",
    "\n",
    "from src.common import features, datasource_utils\n",
    "from src.model_training import data\n",
    "from src.tfx_pipelines import components"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "11ecfb31-0458-4cfc-9afe-a04f1b044e34",
   "metadata": {},
   "outputs": [],
   "source": [
    "logging.getLogger().setLevel(logging.ERROR)\n",
    "tf.get_logger().setLevel('ERROR')\n",
    "\n",
    "print(\"TFX Version:\", tfx.__version__)\n",
    "print(\"Tensorflow Version:\", tf.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c5ec6f8-634f-47b4-85ab-756314fa8129",
   "metadata": {},
   "outputs": [],
   "source": [
    "PARENT = f\"projects/{PROJECT}/locations/{REGION}\"\n",
    "    \n",
    "print(\"Project ID:\", PROJECT)\n",
    "print(\"Region:\", REGION)\n",
    "print(\"Bucket name:\", BUCKET)\n",
    "print(\"Service Account:\", SERVICE_ACCOUNT)\n",
    "print(\"Vertex API Parent URI:\", PARENT)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "377f66a3-7e5b-46b3-b8ea-cf38b5fe92d4",
   "metadata": {},
   "source": [
    "## Create Interactive TFX Context"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2abcd9d2-cc09-4ba5-979e-91696db38d03",
   "metadata": {},
   "outputs": [],
   "source": [
    "REMOVE_ARTIFACTS = True\n",
    "\n",
    "if tf.io.gfile.exists(ARTIFACT_STORE) and REMOVE_ARTIFACTS:\n",
    "    print(\"Removing previous artifacts...\")\n",
    "    tf.io.gfile.rmtree(ARTIFACT_STORE)\n",
    "    \n",
    "if tf.io.gfile.exists(MLMD_SQLLITE) and REMOVE_ARTIFACTS:\n",
    "    print(\"Deleting previous mlmd.sqllite...\")\n",
    "    tf.io.gfile.remove(MLMD_SQLLITE)  # a single file, not a directory tree\n",
    "    \n",
    "print(f'Pipeline artifacts directory: {PIPELINE_ROOT}')\n",
    "print(f'Local metadata SQLite path: {MLMD_SQLLITE}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "751e4415-f6d8-48fe-b586-d2f15eafc5de",
   "metadata": {},
   "outputs": [],
   "source": [
    "connection_config = metadata_store_pb2.ConnectionConfig()\n",
    "connection_config.sqlite.filename_uri = MLMD_SQLLITE\n",
    "connection_config.sqlite.connection_mode = 3 # READWRITE_OPENCREATE\n",
    "mlmd_store = mlmd.metadata_store.MetadataStore(connection_config)\n",
    "\n",
    "context = InteractiveContext(\n",
    "  pipeline_name=PIPELINE_NAME,\n",
    "  pipeline_root=PIPELINE_ROOT,\n",
    "  metadata_connection_config=connection_config\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b17c3103-8aec-4bfb-87ea-6f4d52f4aa37",
   "metadata": {},
   "source": [
    "### Pipeline Step 1: Hyperparameter generation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2fdcfe55-01fe-47fb-ae39-ae02ba2c761d",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "batch_size = 512\n",
    "\n",
    "hyperparams_gen = components.hyperparameters_gen(\n",
    "    num_epochs=5,\n",
    "    learning_rate=0.001,\n",
    "    batch_size=batch_size,\n",
    "    hidden_units='64,64',\n",
    "    steps_per_epoch=LIMIT // batch_size\n",
    ")\n",
    "\n",
    "context.run(hyperparams_gen, enable_cache=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "abb8af0d-2e85-449b-bab7-1b9da8f17caf",
   "metadata": {},
   "source": [
    "#### Load the output of the component from Cloud Storage to check"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "637629a3-e72d-40cf-8753-9e0cf7ed997a",
   "metadata": {},
   "outputs": [],
   "source": [
    "gcs_uri_ouput = hyperparams_gen.outputs['hyperparameters'].get()[0].uri\n",
    "gcs_uri_ouput"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "274b6774-220e-4e90-86b4-4b9aec9341a8",
   "metadata": {},
   "outputs": [],
   "source": [
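    "# Open the file in a context manager so the handle is closed after reading\n",
    "with tf.io.gfile.GFile(os.path.join(gcs_uri_ouput, 'hyperparameters.json')) as f:\n",
    "    hyperparameters = json.load(f)\n",
    "\n",
    "hyperparameters"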
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4ae3bd61-080c-4af0-a88d-f58d9485508a",
   "metadata": {},
   "source": [
    "### Pipeline Step 2: Extract data from BigQuery to Cloud Storage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39997b04-a2a3-4490-88de-4e218376de3e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "def sql_query(ml_use, limit=None):\n",
    "    return datasource_utils.get_training_source_query(PROJECT, REGION, VERTEX_DATASET_NAME, ml_use=ml_use, limit=limit)\n",
    "\n",
    "def output_config(splits):\n",
    "    return example_gen_pb2.Output(\n",
    "        split_config=example_gen_pb2.SplitConfig(\n",
    "            splits=[example_gen_pb2.SplitConfig.Split(name=split_name, hash_buckets=buckets) for (split_name, buckets) in splits]\n",
    "        )\n",
    "    )\n",
    "\n",
    "train_example_gen = BigQueryExampleGen(query=sql_query('UNASSIGNED', LIMIT), output_config=output_config([('train', 4), ('eval', 1)]))\n",
    "\n",
    "beam_pipeline_args=[\n",
    "    f\"--project={PROJECT}\",\n",
    "    f\"--temp_location={os.path.join(WORKSPACE, 'tmp')}\"\n",
    "]\n",
    "\n",
    "context.run(\n",
    "    train_example_gen,\n",
    "    beam_pipeline_args=beam_pipeline_args,\n",
    "    enable_cache=False\n",
    ")"
   ]
  },
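  {
   "cell_type": "markdown",
   "id": "hash-buckets-sketch-02",
   "metadata": {},
   "source": [
    "The `output_config` helper above assigns examples to splits in proportion to their `hash_buckets`, so `('train', 4), ('eval', 1)` yields roughly an 80/20 split. The sketch below illustrates proportional bucket assignment in plain Python. It is not TFX's implementation: the hash function (SHA-256) and the helper name `assign_split` are assumptions made for illustration.\n",
    "\n",
    "```python\n",
    "import hashlib\n",
    "from collections import Counter\n",
    "\n",
    "\n",
    "def assign_split(key, splits):\n",
    "    # splits: list of (name, hash_buckets) pairs, e.g. [('train', 4), ('eval', 1)]\n",
    "    total = sum(buckets for _, buckets in splits)\n",
    "    digest = hashlib.sha256(key.encode('utf-8')).digest()\n",
    "    bucket = int.from_bytes(digest[:8], 'big') % total\n",
    "    upper = 0\n",
    "    for name, buckets in splits:\n",
    "        upper += buckets  # each split owns a contiguous range of buckets\n",
    "        if bucket < upper:\n",
    "            return name\n",
    "    raise AssertionError('unreachable: bucket is always below total')\n",
    "\n",
    "\n",
    "counts = Counter(\n",
    "    assign_split(f'row-{i}', [('train', 4), ('eval', 1)]) for i in range(10_000)\n",
    ")\n",
    "print(counts)\n",
    "```\n",
    "\n",
    "Changing the bucket counts changes the ratio, for example `('train', 9), ('eval', 1)` would give roughly 90/10."
   ]
  },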
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c83709a-e59f-405e-8ab9-6185feac6c89",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "test_example_gen = BigQueryExampleGen(query=sql_query('TEST'), output_config=output_config([('test', 1)]))\n",
    "\n",
    "context.run(\n",
    "    test_example_gen,\n",
    "    beam_pipeline_args=beam_pipeline_args,\n",
    "    enable_cache=False\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "10f80eb6-13ea-4550-8ecb-069364f4538a",
   "metadata": {},
   "source": [
    "#### Read some TFRecords from the training data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d417fc7f-5a9e-4214-b01c-8305942db796",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9053e250-45f5-40b5-8d5b-7469a47c0c4d",
   "metadata": {},
   "outputs": [],
   "source": [
    "train_uri = os.path.join(train_example_gen.outputs['examples'].get()[0].uri, \"Split-train/*\")\n",
    "\n",
    "source_raw_schema = tfdv.load_schema_text(os.path.join(RAW_SCHEMA_DIR, 'schema.pbtxt'))\n",
    "raw_feature_spec = schema_utils.schema_as_feature_spec(source_raw_schema).feature_spec"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fff0ec27-c766-47fa-87f9-a21461029806",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _parse_tf_example(tfrecord):\n",
    "    return tf.io.parse_single_example(tfrecord, raw_feature_spec)\n",
    "\n",
    "tfrecord_filenames = tf.data.Dataset.list_files(train_uri)\n",
    "dataset = tf.data.TFRecordDataset(tfrecord_filenames, compression_type=\"GZIP\")\n",
    "dataset = dataset.map(_parse_tf_example)\n",
    "\n",
    "for raw_features in dataset.shuffle(1000).batch(3).take(1):\n",
    "    for key in raw_features:\n",
    "        print(f\"{key}: {np.squeeze(raw_features[key], -1)}\")\n",
    "    print(\"\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3d98b8cd-d014-4b7a-b2ca-9f1e3400bcf5",
   "metadata": {},
   "source": [
    "### Pipeline Step 3: Data Validation\n",
    "\n",
    "Import the schema, generate statistics and validate the statistics against the schema."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2146edf3-01a3-4c65-80ba-d76f1604e5ad",
   "metadata": {},
   "outputs": [],
   "source": [
    "schema_importer = tfx.dsl.Importer(\n",
    "    source_uri=RAW_SCHEMA_DIR,\n",
    "    artifact_type=tfx.types.standard_artifacts.Schema,\n",
    "    reimport=False\n",
    ")\n",
    "\n",
    "context.run(schema_importer)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1847ecfa-4656-4efd-a234-657c78ec2968",
   "metadata": {},
   "source": [
    "Generate statistics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "385fecfd-d582-4ff1-b0f7-9266f93d91cb",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "statistics_gen = tfx.components.StatisticsGen(\n",
    "    examples=train_example_gen.outputs['examples'])\n",
    "context.run(statistics_gen)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "52b8b551-e8cf-4dfd-8b0e-75923859687d",
   "metadata": {},
   "outputs": [],
   "source": [
    "!rm -rf {RAW_SCHEMA_DIR}/.ipynb_checkpoints/"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1787919-0359-4003-b26a-d39c755b41f2",
   "metadata": {},
   "source": [
    "Validate statistics against schema"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "503235ce-b751-4d9e-9968-af7d86dfa20e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "example_validator = tfx.components.ExampleValidator(\n",
    "    statistics=statistics_gen.outputs['statistics'],\n",
    "    schema=schema_importer.outputs['result'],\n",
    ")\n",
    "\n",
    "context.run(example_validator)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bda0fdfc-b146-4224-b55d-1a904e374ac1",
   "metadata": {},
   "outputs": [],
   "source": [
    "context.show(example_validator.outputs['anomalies'])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ce1eea1c-4ad0-4ec2-8a0b-155aee16b92e",
   "metadata": {},
   "source": [
    "### Pipeline Step 4: Data Preprocessing using TensorFlow Transform (TFT)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c67235de-ce78-430b-8214-648e83789b94",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "_transform_module_file = 'src/preprocessing/transformations.py'\n",
    "\n",
    "transform = tfx.components.Transform(\n",
    "    examples=train_example_gen.outputs['examples'],\n",
    "    schema=schema_importer.outputs['result'],\n",
    "    module_file=_transform_module_file,\n",
    "    splits_config=transform_pb2.SplitsConfig(\n",
    "        analyze=['train'], transform=['train', 'eval']),\n",
    ")\n",
    "\n",
    "context.run(transform, enable_cache=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7f82f156-ed6f-40d0-bde1-0acf0fb0b1b8",
   "metadata": {},
   "source": [
    "#### Test: Read an example of the transformed data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a9204ecf-e700-4b5b-81b3-32b18e9eed56",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "transformed_train_uri = os.path.join(transform.outputs['transformed_examples'].get()[0].uri, \"Split-train/*\")\n",
    "transform_graph_uri = transform.outputs['transform_graph'].get()[0].uri\n",
    "\n",
    "tft_output = tft.TFTransformOutput(transform_graph_uri)\n",
    "transform_feature_spec = tft_output.transformed_feature_spec()\n",
    "\n",
    "for input_features, target in data.get_dataset(\n",
    "    transformed_train_uri, transform_feature_spec, batch_size=3, epochs=1).take(1):\n",
    "    for key in input_features:\n",
    "        print(f\"{key} ({input_features[key].dtype}): {input_features[key].numpy().tolist()}\")\n",
    "    print(f\"target: {target.numpy().tolist()}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "43191d8d-1689-4812-a07c-4811306ac112",
   "metadata": {},
   "source": [
    "### Pipeline Step 5: Model Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6f7fb3d7-a386-4b5f-aa55-c23b94a305b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tfx.dsl.components.common.resolver import Resolver\n",
    "from tfx.dsl.experimental import latest_blessed_model_resolver"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aa33f320-fe92-4ee7-99ac-9486b0c9510e",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "_train_module_file = 'src/model_training/runner.py'\n",
    "\n",
    "trainer = tfx.components.Trainer(\n",
    "    module_file=_train_module_file,\n",
    "    examples=transform.outputs['transformed_examples'],\n",
    "    schema=schema_importer.outputs['result'],\n",
    "    transform_graph=transform.outputs['transform_graph'],\n",
    "    hyperparameters=hyperparams_gen.outputs['hyperparameters'],\n",
    ")\n",
    "\n",
    "context.run(trainer, enable_cache=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "cce6e6f9-6ca1-48cc-8b72-7a0e8e37c856",
   "metadata": {},
   "source": [
    "### Pipeline Step 6: Model Evaluation\n",
    "\n",
    "#### Get the latest blessed model for model validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b543c1d7-37f1-49f6-b29d-518b4ff3f1d0",
   "metadata": {},
   "outputs": [],
   "source": [
    "blessed_model_resolver = Resolver(\n",
    "    strategy_class=latest_blessed_model_resolver.LatestBlessedModelResolver,\n",
    "    model=tfx.dsl.Channel(type=tfx.types.standard_artifacts.Model),\n",
    "    model_blessing=tfx.dsl.Channel(type=tfx.types.standard_artifacts.ModelBlessing)\n",
    ")\n",
    "\n",
    "context.run(blessed_model_resolver, enable_cache=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "64604298-be89-4c3e-9677-c292fdfadb43",
   "metadata": {},
   "source": [
    "#### Evaluate the model and compare against the baseline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2085e54a-1a15-437e-bd70-9146d0ecebd9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from tfx.components import Evaluator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6b9c9ac9-46f3-45ed-8750-f18fd15ab106",
   "metadata": {},
   "outputs": [],
   "source": [
    "eval_config = tfma.EvalConfig(\n",
    "    model_specs=[\n",
    "        tfma.ModelSpec(\n",
    "            signature_name='serving_tf_example',\n",
    "            label_key=features.TARGET_FEATURE_NAME,\n",
    "            prediction_key='probabilities')\n",
    "    ],\n",
    "    slicing_specs=[\n",
    "        tfma.SlicingSpec(),\n",
    "    ],\n",
    "    metrics_specs=[\n",
    "        tfma.MetricsSpec(\n",
    "            metrics=[   \n",
    "                tfma.MetricConfig(class_name='ExampleCount'),\n",
    "                tfma.MetricConfig(\n",
    "                    class_name='BinaryAccuracy',\n",
    "                    threshold=tfma.MetricThreshold(\n",
    "                        value_threshold=tfma.GenericValueThreshold(\n",
    "                            lower_bound={'value': 0.1}),  # Note: deliberately low threshold for this example\n",
    "                        # Change threshold will be ignored if there is no\n",
    "                        # baseline model resolved from MLMD (first run).\n",
    "                        change_threshold=tfma.GenericChangeThreshold(\n",
    "                            direction=tfma.MetricDirection.HIGHER_IS_BETTER,\n",
    "                            absolute={'value': -1e-10}))),\n",
    "        ])\n",
    "    ])\n",
    "\n",
    "\n",
    "evaluator = Evaluator(\n",
    "    examples=test_example_gen.outputs['examples'],\n",
    "    example_splits=['test'],\n",
    "    model=trainer.outputs['model'],\n",
    "    baseline_model=blessed_model_resolver.outputs['model'],\n",
    "    eval_config=eval_config,\n",
    "    schema=schema_importer.outputs['result']\n",
    ")\n",
    "\n",
    "context.run(evaluator, enable_cache=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d284e70-2185-4b71-8f23-cfaef4ec52d6",
   "metadata": {},
   "outputs": [],
   "source": [
    "evaluation_results = evaluator.outputs['evaluation'].get()[0].uri\n",
    "print(\"validation_ok:\", tfma.load_validation_result(evaluation_results).validation_ok, '\\n')\n",
    "\n",
    "for entry in list(tfma.load_metrics(evaluation_results))[0].metric_keys_and_values:\n",
    "    value = entry.value.double_value.value\n",
    "    if value:\n",
    "        print(entry.key.name, \":\", round(entry.value.double_value.value, 3))"
   ]
  },
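  {
   "cell_type": "markdown",
   "id": "f3a91c7e-5b2d-4e86-9a40-2c7d1e8b6f13",
   "metadata": {},
   "source": [
    "The two thresholds configured in `eval_config` can be illustrated with a minimal, library-free sketch. The function `passes_thresholds` and its parameter names are hypothetical and are not part of TFMA. The sketch assumes that validation reduces to a lower bound on the candidate metric plus, when a baseline exists, a minimum allowed change `candidate - baseline`; TFMA's actual implementation may differ in detail."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7c2e4b9a-1d63-4f0a-8e57-b94a0d3c5e21",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical helper mirroring the intent of eval_config (not TFMA code).\n",
    "def passes_thresholds(candidate, baseline=None, lower_bound=0.1, min_change=-1e-10):\n",
    "    # Value threshold: the candidate metric must reach the lower bound.\n",
    "    if candidate < lower_bound:\n",
    "        return False\n",
    "    # Change threshold: skipped when no baseline exists (first run).\n",
    "    if baseline is None:\n",
    "        return True\n",
    "    # HIGHER_IS_BETTER: the candidate may not fall below the baseline\n",
    "    # by more than the (tiny) tolerance.\n",
    "    return (candidate - baseline) > min_change\n",
    "\n",
    "\n",
    "print(passes_thresholds(0.95))                 # first run, no baseline\n",
    "print(passes_thresholds(0.95, baseline=0.90))  # candidate improves on the baseline\n",
    "print(passes_thresholds(0.90, baseline=0.95))  # candidate regresses"
   ]
  },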
  {
   "cell_type": "markdown",
   "id": "975e2cfd-8524-4c25-88a7-fbee8b215352",
   "metadata": {},
   "source": [
    "### Pipeline Step 7: Push model to Cloud Storage"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "073bfac2-ee38-4b05-9642-1e29fbdcc7ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "exported_model_location = os.path.join(MODEL_REGISTRY, MODEL_DISPLAY_NAME)\n",
    "\n",
    "push_destination=tfx.proto.PushDestination(\n",
    "    filesystem=tfx.proto.PushDestination.Filesystem(\n",
    "        base_directory=exported_model_location,\n",
    "    )\n",
    ")\n",
    "\n",
    "pusher = tfx.components.Pusher(\n",
    "    model=trainer.outputs['model'],\n",
    "    model_blessing=evaluator.outputs['blessing'],\n",
    "    push_destination=push_destination\n",
    ")\n",
    "\n",
    "context.run(pusher, enable_cache=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c6e8dbe5-f3fa-4ce6-8057-d19a672350e2",
   "metadata": {},
   "source": [
    "### Pipeline Step 8: Upload model to Vertex AI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d8c0686-f971-4b60-9db1-6036052c2098",
   "metadata": {},
   "outputs": [],
   "source": [
    "serving_runtime = 'tf2-cpu.2-5'\n",
    "serving_image_uri = f\"europe-docker.pkg.dev/vertex-ai/prediction/{serving_runtime}:latest\"\n",
    "\n",
    "labels = {\n",
    "    'dataset_name': VERTEX_DATASET_NAME,\n",
    "    'pipeline_name': PIPELINE_NAME\n",
    "}\n",
    "labels = json.dumps(labels)\n",
    "\n",
    "vertex_model_uploader = components.vertex_model_uploader(\n",
    "    project=PROJECT,\n",
    "    region=REGION,\n",
    "    model_display_name=MODEL_DISPLAY_NAME,\n",
    "    pushed_model_location=exported_model_location,\n",
    "    serving_image_uri=serving_image_uri,\n",
    "    model_blessing=evaluator.outputs['blessing'],\n",
    "    explanation_config='',\n",
    "    labels=labels\n",
    ")\n",
    "\n",
    "context.run(vertex_model_uploader, enable_cache=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba3923d4-0048-482c-96f7-3b52782355df",
   "metadata": {},
   "outputs": [],
   "source": [
    "vertex_model_uploader.outputs['uploaded_model'].get()[0].uri"
   ]
  }
 ],
 "metadata": {
  "environment": {
   "kernel": "python3",
   "name": "tf2-gpu.2-8.m93",
   "type": "gcloud",
   "uri": "gcr.io/deeplearning-platform-release/tf2-gpu.2-8:m93"
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
